{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from notebook_utils import *"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Select your checkpoint"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Repository root, located through the `.project-root` indicator.\n",
                "path = pyrootutils.find_root(indicator=\".project-root\")\n",
                "\n",
                "# Run to visualise and the checkpoint stored inside it.\n",
                "run_path = Path(\"logs/PointBeV/effb4/vis1_r3/\")\n",
                "ckpt_path = path / run_path / \"checkpoints\" / \"38_69.ckpt\"\n",
                "\n",
                "# Hydra files of the run, written as a relative path starting one directory up.\n",
                "config_path = Path(\"..\") / run_path / \".hydra\"\n",
                "overrides_path = config_path.joinpath(\"overrides.yaml\")\n",
                "\n",
                "# Device used for the model and the input tensors.\n",
                "device = \"cuda\""
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Load overrides"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Overrides stored in the run's `overrides.yaml`.\n",
                "overrides = OmegaConf.load(overrides_path)\n",
                "\n",
                "# Drop every entry containing a \"/\" (presumably config-group selections -- confirm).\n",
                "overrides = [entry for entry in overrides if \"/\" not in entry]"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Create the associated configuration file."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Overrides applied on top of the ones recorded with the run.\n",
                "extra_overrides = [\n",
                "    \"data.version=trainval\",\n",
                "    \"data.batch_size=1\",\n",
                "    \"data.valid_batch_size=1\",\n",
                "    # Optional sparse-evaluation settings: uncomment to enable them.\n",
                "    # \"model.net.sampled_kwargs.val_mode=regular_pillars\",\n",
                "    # \"model.net.sampled_kwargs.patch_size=1\",\n",
                "    # \"model.net.sampled_kwargs.valid_fine=True\",\n",
                "    # \"model.net.sampled_kwargs.N_coarse=2000\",\n",
                "    # \"model.net.sampled_kwargs.N_fine=dyna\",\n",
                "    # \"model.net.sampled_kwargs.N_anchor=dyna\",\n",
                "    # \"model.net.sampled_kwargs.fine_thresh=0.1\",\n",
                "    # \"model.net.sampled_kwargs.fine_patch_size=9\",\n",
                "]\n",
                "\n",
                "with initialize(version_base=\"1.3\", config_path=str(config_path)):\n",
                "    cfg = compose(\n",
                "        config_name=\"config.yaml\",\n",
                "        return_hydra_config=True,\n",
                "        overrides=overrides + extra_overrides,\n",
                "    )\n",
                "\n",
                "# `path` already holds the repository root found with pyrootutils.\n",
                "cfg.paths.root_dir = str(path)\n",
                "# Stored as a plain string: OmegaConf releases before 2.3 reject pathlib.Path values.\n",
                "cfg.ckpt.path = str(ckpt_path)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Create model"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Build the model described by the composed configuration.\n",
                "model = hydra.utils.instantiate(cfg.model)\n",
                "\n",
                "# Read the checkpoint and load it into the model, following the\n",
                "# `freeze` / `load` settings found under `cfg.ckpt.model`.\n",
                "ckpt = utils.get_ckpt_from_path(cfg.ckpt.path)\n",
                "model = utils.load_state_model(\n",
                "    model, ckpt, cfg.ckpt.model.freeze, cfg.ckpt.model.load, verbose=1\n",
                ")\n",
                "\n",
                "# Move to the target device, then switch to evaluation mode.\n",
                "model.to(device)\n",
                "model.eval()\n",
                "\n",
                "# The cells below call the network through the name `self`.\n",
                "self = model.net"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Dataset"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Build the datamodule described by the composed configuration.\n",
                "datamodule = hydra.utils.instantiate(cfg.data)\n",
                "\n",
                "# Request shuffled validation batches (presumably read by `val_dataloader()` -- confirm).\n",
                "datamodule.val_shuffle = True\n",
                "\n",
                "datamodule.setup()"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Sampled data"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# First batch produced by a fresh validation dataloader.\n",
                "batch = next(iter(datamodule.val_dataloader()))\n",
                "\n",
                "# Apply the datamodule's after-transfer hook by hand (no dataloader index).\n",
                "data = datamodule.on_after_batch_transfer(batch, None)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Move each input tensor of the batch onto the target device.\n",
                "imgs, rots, trans, intrins, bev_aug, egoTin_to_seq, egoTout_to_seq = (\n",
                "    data[name].to(device)\n",
                "    for name in (\n",
                "        \"imgs\",\n",
                "        \"rots\",\n",
                "        \"trans\",\n",
                "        \"intrins\",\n",
                "        \"bev_aug\",\n",
                "        \"egoTin_to_seq\",\n",
                "        \"egoTout_to_seq\",\n",
                "    )\n",
                ")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Inference"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# inference_mode() already disables gradient tracking, so a nested\n",
                "# no_grad() adds nothing.\n",
                "with torch.inference_mode():\n",
                "    out = self(imgs, rots, trans, intrins, bev_aug, egoTin_to_seq)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Visualisation"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "b_ts = 0  # batch index displayed by visualise()\n",
                "t_ts = 0  # time-step index displayed by visualise()\n",
                "cmap = \"Blues\"  # colour map of the ground-truth and prediction panels\n",
                "\n",
                "\n",
                "def visualise(imgs, data, out):\n",
                "    \"\"\"Draw six camera views beside the BeV ground truth, prediction and mask.\n",
                "\n",
                "    The batch element, the time step and the colour map are taken from the\n",
                "    notebook-level variables ``b_ts``, ``t_ts`` and ``cmap``.\n",
                "\n",
                "    Args:\n",
                "        imgs: Camera images, indexed as ``imgs[b_ts, t_ts, cam]`` for ``cam`` in 0..5.\n",
                "        data: Batch dictionary with a ``\"binimg\"`` tensor and, optionally, a\n",
                "            ``\"visibility\"`` tensor (treated as all ones when absent).\n",
                "        out: Network output providing ``out[\"bev\"][\"binimg\"]`` (a sigmoid is\n",
                "            applied before plotting) and ``out[\"masks\"][\"bev\"][\"binimg\"]``.\n",
                "\n",
                "    Returns:\n",
                "        The matplotlib figure holding all the panels.\n",
                "    \"\"\"\n",
                "    key = \"binimg\"\n",
                "\n",
                "    # The prediction and the mask are moved to the CPU below, so the ground\n",
                "    # truth has to be there too before everything is concatenated.\n",
                "    target = data[key].detach().cpu()\n",
                "    if \"visibility\" in data:\n",
                "        visibility = data[\"visibility\"].detach().cpu()\n",
                "    else:\n",
                "        visibility = torch.ones_like(target)\n",
                "\n",
                "    # Prepare the images\n",
                "    cam_imgs = [utils.imgs.DENORMALIZE_IMG(imgs[b_ts, t_ts, cam]) for cam in range(6)]\n",
                "\n",
                "    # Ground-truth weight: 0 where visibility < 1, 1 where 1 <= visibility < 2,\n",
                "    # 3 where visibility >= 2.\n",
                "    visibility_weight = (visibility >= 2) * 2 + (visibility >= 1)\n",
                "    bev_imgs = torch.cat(\n",
                "        [\n",
                "            # Ground truth\n",
                "            (target * visibility_weight)[b_ts, t_ts],\n",
                "            # Prediction\n",
                "            out[\"bev\"][key][b_ts, t_ts].detach().cpu().sigmoid(),\n",
                "            # Mask\n",
                "            out[\"masks\"][\"bev\"][key][b_ts, t_ts].detach().cpu(),\n",
                "        ]\n",
                "    )\n",
                "\n",
                "    # Create figure and axes\n",
                "    fig = plt.figure(figsize=(4.2 * 5, 1 * 5))  # Adjust figsize as needed\n",
                "\n",
                "    # Three columns for the cameras plus two columns for each of the three BeV panels.\n",
                "    num_cols = 3 + 6\n",
                "    gs = gridspec.GridSpec(2, num_cols, figure=fig, wspace=0.0, hspace=0.0)\n",
                "\n",
                "    # Cameras: a 2 x 3 grid in the left-most columns.\n",
                "    cam_axes = [fig.add_subplot(gs[row, col]) for row in range(2) for col in range(num_cols - 6)]\n",
                "    for cam_ax, img in zip(cam_axes, cam_imgs):\n",
                "        cam_ax.imshow(img)\n",
                "        cam_ax.set_xticks([])\n",
                "        cam_ax.set_yticks([])\n",
                "    cam_axes[1].set_title(\"Cameras\", fontsize=20)\n",
                "\n",
                "    # BeV\n",
                "    ax_gt = fig.add_subplot(gs[0:2, 3:5])\n",
                "    ax_gt.imshow(bev_imgs[0], cmap=cmap)\n",
                "\n",
                "    # Pin the colour scale to the sigmoid range [0, 1] so every frame uses the\n",
                "    # same mapping, without overwriting any pixel of the prediction.\n",
                "    ax_pred = fig.add_subplot(gs[0:2, 5:7])\n",
                "    ax_pred.imshow(bev_imgs[1], cmap=cmap, vmin=0, vmax=1)\n",
                "\n",
                "    ax_mask = fig.add_subplot(gs[0:2, 7:9])\n",
                "    ax_mask.imshow(bev_imgs[2], cmap=matplotlib.cm.Purples)\n",
                "\n",
                "    for ax, title in zip([ax_gt, ax_pred, ax_mask], [\"Ground truth\", \"Prediction\", \"Mask\"]):\n",
                "        ax.set_xticks([])\n",
                "        ax.set_yticks([])\n",
                "        ax.set_title(title, fontsize=20)\n",
                "\n",
                "    return fig\n",
                "\n",
                "\n",
                "# Show the plot. plt.show() is used because Figure.show() only warns when the\n",
                "# backend is not a GUI one, as in a notebook.\n",
                "fig = visualise(imgs, data, out)\n",
                "plt.show()"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Create a clip"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Number of frames requested for the clip.\n",
                "max_frames = 100\n",
                "\n",
                "# Converter from a PyTorch tensor to a PIL image.\n",
                "transform = transforms.ToPILImage()\n",
                "\n",
                "# Iterator over a fresh validation dataloader; the frames are drawn from it.\n",
                "iter_loader = iter(datamodule.val_dataloader())"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "frame_filenames = []  # PNG files written for the clip, in display order\n",
                "\n",
                "# zip() stops as soon as either side is exhausted, so a validation loader with\n",
                "# fewer than `max_frames` batches gives a shorter clip instead of raising\n",
                "# StopIteration. The progress bar comes first so that no extra batch is drawn\n",
                "# once `max_frames` is reached.\n",
                "for frame_number, input_data in zip(trange(max_frames), iter_loader):\n",
                "    data = datamodule.on_after_batch_transfer(input_data, None)\n",
                "\n",
                "    imgs = data[\"imgs\"].to(device)\n",
                "    rots = data[\"rots\"].to(device)\n",
                "    trans = data[\"trans\"].to(device)\n",
                "    intrins = data[\"intrins\"].to(device)\n",
                "    bev_aug = data[\"bev_aug\"].to(device)\n",
                "    egoTin_to_seq = data[\"egoTin_to_seq\"].to(device)\n",
                "\n",
                "    # inference_mode() already disables gradient tracking.\n",
                "    with torch.inference_mode():\n",
                "        out = self(imgs, rots, trans, intrins, bev_aug, egoTin_to_seq)\n",
                "\n",
                "    # Save and close this very figure rather than the implicit current one.\n",
                "    fig = visualise(imgs, data, out)\n",
                "    frame_filename = f\"frame_{frame_number}.png\"\n",
                "    fig.savefig(frame_filename, bbox_inches=\"tight\")\n",
                "    frame_filenames.append(frame_filename)\n",
                "    plt.close(fig)\n",
                "\n",
                "if not frame_filenames:\n",
                "    raise RuntimeError(\"No frame was rendered: the validation dataloader yielded no batch.\")\n",
                "\n",
                "# Create GIF from frames\n",
                "frames = [Image.open(filename) for filename in frame_filenames]\n",
                "try:\n",
                "    frames[0].save(\n",
                "        \"../notebooks/test.gif\",\n",
                "        format=\"GIF\",\n",
                "        append_images=frames[1:],\n",
                "        save_all=True,\n",
                "        duration=1000,\n",
                "        loop=0,\n",
                "    )\n",
                "finally:\n",
                "    # Release the file handles before the PNG files are deleted.\n",
                "    for frame in frames:\n",
                "        frame.close()\n",
                "\n",
                "# Optional: Clean up by removing individual frame files\n",
                "for filename in frame_filenames:\n",
                "    os.remove(filename)"
            ]
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "Python 3",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.11.8"
        },
        "orig_nbformat": 4
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
